#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/4/8 20:26
# @Author  : Hbber
# @Version : 0.0.0
# @File    : tools.py
# @Software: PyCharm
# @Org     : AirCas
# @Describe: 工具函数

# @Last Modify Time          @Version        @Description
# --------------------       --------        -----------
# 2021/4/8 20:26            0.0.1           None

import hashlib
import re
import struct
import time
from typing import Optional, Union

import serial
from serial import SerialException

from .my_logging import getLogger
from .enum_constants import *

logger = getLogger(__name__)


def get_timestamp_ms() -> int:
    """Return the current wall-clock time in milliseconds since the epoch, rounded to an int."""
    seconds_since_epoch = time.time()
    return round(seconds_since_epoch * 1000)


def md5(data: bytes):
    """Return the MD5 digest of ``data`` as a lowercase hexadecimal string."""
    digest = hashlib.md5()
    digest.update(data)
    return digest.hexdigest()


def check_port(port: str) -> bool:
    """
    Check whether ``port`` looks like a valid serial port name.

    Accepted forms:
        * Windows: ``COM<n>``, e.g. ``COM3``
        * Linux:   ``/dev/ttyS<n>``, ``/dev/ttyUSB<n>`` or ``/dev/ttyACM<n>``

    The whole string must match, so names with trailing characters
    (including a trailing newline) are rejected.

    :param port: serial port name to validate
    :type port: str
    :return: True if the name matches one of the accepted forms, else False
    :rtype: bool
    """
    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing "\n", which would let "COM1\n" through.
    port_pattern = r"COM[0-9]+|/dev/tty(?:S|USB|ACM)[0-9]+"
    return re.fullmatch(port_pattern, port) is not None


def checksum(payload: Union[bytes, bytearray, list, int]) -> bytes:
    """
    Compute the 8-bit checksum of a UART command payload.

    The checksum is the unsigned sum of every payload byte with only the
    least-significant 8 bits kept::

        checksum = 0
        for each byte in the payload:
            checksum = (checksum + byte) AND 0xFF

    A single ``int`` is treated as a one-byte payload, so the result is that
    byte itself.

    :param payload: payload bytes; a list must hold 1..240 ints in 0..255,
        an int must be in 0..255
    :type payload: bytes | bytearray | list | int
    :return: the checksum as a single byte
    :rtype: bytes
    :raises TypeError: for unsupported types, a list of invalid length or
        content, or an int that does not fit in one byte
    """
    if not isinstance(payload, (bytes, bytearray, list, int)):
        raise TypeError('需要bytearray或list: int或int类型的参数，不支持：{}'.format(type(payload)))

    # One byte is its own checksum; bytes([...]) validates the 0..255 range.
    if isinstance(payload, int):
        try:
            return bytes([payload])
        except Exception:
            raise TypeError('不支持int：{}'.format(payload))

    if isinstance(payload, list):
        if not 1 <= len(payload) <= 240:
            raise TypeError('数组的长度至少为1, 最大240')
        try:
            payload = bytearray(payload)
        except Exception:
            raise TypeError('不支持的数组：{}'.format(payload))

    return bytes([sum(payload) & 0xff])


def is_ack_res(data_recv: bytes):
    """
    Tell whether ``data_recv`` is exactly the 5-byte ACK response
    ``00 04 CC 00 CC``.

    :param data_recv: bytes read from the serial port
    :type data_recv: bytes
    :return: True only for an exact match, otherwise False
    :rtype: bool
    """
    expected_hex = bytes((0x00, 0x04, 0xCC, 0x00, 0xCC)).hex()
    # Anything that is not exactly five bytes long cannot be an ACK.
    if len(data_recv) != 5:
        return False
    return data_recv.hex() == expected_hex


class CMD_PKG(object):
    """
    Builder for the UART command packets sent to the device.

    ``get_uart_pkg`` serialises a command as: sync header byte, 2-byte
    big-endian length, 1-byte checksum, command byte, payload bytes.
    """
    # Ready-made packets for the commands that carry no payload.
    PING_COMMAND = bytes([0xAA, 0x00, 0x03, 0x20, 0x20])
    ERASE_COMMAND = bytes([0xAA, 0x00, 0x03, 0x28, 0x28])
    GET_STATUS_COMMAND = bytes([0xAA, 0x00, 0x03, 0x23, 0x23])
    GET_VERSION_COMMAND = bytes([0xAA, 0x00, 0x03, 0x2F, 0x2F])

    def __init__(self, cmd_type: CMD, payload: bytearray):
        """
        Create a command packet.

        :param cmd_type: command identifier
        :type cmd_type: CMD
        :param payload: command arguments, without the command byte
        :type payload: bytearray
        :raises TypeError: if either argument has the wrong type
        """
        if not isinstance(cmd_type, CMD) or not isinstance(payload, bytearray):
            raise TypeError('need cmd_type: CMD, payload: bytearray, but cmd_type: {0}, payload: {1}'.
                            format(type(cmd_type), type(payload)))
        self.cmd_type = cmd_type
        self.payload = payload
        # Value of the length field: payload size plus 3 (presumably the
        # command byte and the 2-byte length field itself -- TODO confirm
        # against the protocol specification).
        self.length = len(payload) + 3
        # The checksum covers the command byte followed by the payload.
        self.real_payload = bytearray([self.cmd_type.value]) + self.payload
        self.ck_sum = checksum(self.real_payload)

    def get_uart_pkg(self):
        """
        Serialise the command into the bytes to write to the serial port.

        :return: sync header + length + checksum + command byte + payload
        :rtype: bytes
        """
        fields = (
            bytes([SYNC_HEADER]),
            struct.pack('>H', self.length),
            self.ck_sum,
            self.real_payload,
        )
        return b"".join(fields)

    @classmethod
    def creat_open_cmd(cls, file_size: int, storage_type: STORAGE_TYPE, file_type: FILE_TYPE):
        """
        Build an OPEN_FILE command.

        The payload is four big-endian 32-bit words: file size, storage
        type, file type and a trailing zero.
        """
        packed_args = struct.pack('>IIII', file_size, storage_type.value, file_type.value, 0)
        return cls(cmd_type=CMD.OPEN_FILE, payload=bytearray(packed_args))

    @classmethod
    def creat_write_to_flash_cmd(cls, data: bytes):
        """
        Build a WRITE_FILE_TO_SFLASH command carrying ``data``.

        :raises TypeError: if ``data`` is longer than 240 bytes
        """
        chunk_len = len(data)
        if chunk_len > 240:
            raise TypeError('单次写入数据需小于240字节，实际data长度：{}'.format(chunk_len))
        return cls(cmd_type=CMD.WRITE_FILE_TO_SFLASH, payload=bytearray(data))

    @classmethod
    def creat_write_to_ram_cmd(cls, data: bytes):
        """
        Build a WRITE_FILE_TO_RAM command carrying ``data``.

        :raises TypeError: if ``data`` is longer than 240 bytes
        """
        chunk_len = len(data)
        if chunk_len > 240:
            raise TypeError('单次写入数据需小于240字节，实际data长度：{}'.format(chunk_len))
        return cls(cmd_type=CMD.WRITE_FILE_TO_RAM, payload=bytearray(data))

    @classmethod
    def creat_close_cmd(cls, storage_type: STORAGE_TYPE):
        """
        Build a CLOSE_FILE command whose payload is the storage type as one
        big-endian 32-bit word.
        """
        packed_args = struct.pack('>I', storage_type.value)
        return cls(cmd_type=CMD.CLOSE_FILE, payload=bytearray(packed_args))


class Flasher(object):
    """
    Firmware downloader for TI mmWave radar devices.

    Opens the serial port and pings the device on construction. Call
    :meth:`close` (or use the instance as a context manager) to release the
    port when finished.
    """
    SUPPORT_DEVICE_LIST = [DEVICE.IWR6843]
    # NOTE(review): this is a relative path, so ``open()`` resolves it against
    # the current working directory, not against this module's location.
    DEMO_IMG = {
        DEVICE.IWR6843: "../resources/xwr68xx_mmw_demo.bin"
    }

    # Read timeout (seconds) the port is returned to after every transaction.
    _DEFAULT_TIMEOUT = 1
    # Read timeout (seconds) used while waiting for ACKs during a download.
    _ACK_TIMEOUT = 10
    # Maximum number of file bytes carried by a single write command.
    _CHUNK_SIZE = 240

    def __init__(self,
                 uart_port: str = IWR6843_ISK_PORT,
                 device_type: DEVICE = DEVICE.IWR6843):
        """
        Create a downloader bound to a serial port and a device type.

        The port is opened and the device is pinged; if the ping is not
        acknowledged the port is closed again before the error is raised.

        :param uart_port: serial port name, e.g. "COMx" or "/dev/ttySx"
        :type uart_port: str
        :param device_type: device type; only IWR6843 is supported for now
        :type device_type: DEVICE
        :raises TypeError: wrong argument types, invalid port name or
            unsupported device
        :raises SerialException: the port could not be opened or used
        :raises RuntimeError: the device did not acknowledge the ping
        """
        if not (isinstance(uart_port, str) and isinstance(device_type, DEVICE)):
            raise TypeError('输入参数类型错误，需要uart_port: str, device_type: DEVICE，'
                            '但是uart_port: {}, device_type: {}'.format(type(uart_port), type(device_type)))
        if check_port(uart_port) is False:
            raise TypeError('串口号不合规：{}，需要类似于："COM1" "/dev/ttyS1"'.format(uart_port))
        if device_type not in self.SUPPORT_DEVICE_LIST:
            raise TypeError('不支持的器件：{}，目前支持的器件列表是：{}'.format(device_type.value, self.SUPPORT_DEVICE_LIST))
        self.uart_port = uart_port
        self.device_type = device_type
        self.attached = False
        self._serial = None
        # Create and open the serial port, then check the device answers.
        try:
            self._serial = serial.Serial(port=self.uart_port, baudrate=BAUD_RATE,
                                         timeout=self._DEFAULT_TIMEOUT)
            self.attached = self.__ping()
        except SerialException as e:
            logger.error('打开串口失败:{}'.format(str(e)))
            # Do not leak a half-initialised, still-open port.
            self.close()
            raise
        if not self.attached:
            logger.error('{}连接失败, 请确认芯片SOP2为High！'.format(device_type.name))
            # The instance is unusable; release the port before failing.
            self.close()
            raise RuntimeError(f'{device_type.name}连接失败')
        logger.info('{}连接成功'.format(device_type.name))
        time.sleep(0.5)

    def __enter__(self):
        """Allow ``with Flasher(...) as flasher:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the serial port on leaving the ``with`` block; exceptions propagate."""
        self.close()

    def close(self) -> None:
        """
        Close the serial port if it is open and mark the device as detached.

        Safe to call more than once.
        """
        port = getattr(self, '_serial', None)
        if port is not None and port.is_open:
            port.close()
        self.attached = False

    def _read_response(self, size: int, timeout: float) -> bytes:
        """
        Read up to ``size`` bytes using ``timeout`` seconds as read timeout.

        The port's timeout is always reset to the default afterwards, even
        if the read raises.
        """
        self._serial.timeout = timeout
        try:
            return self._serial.read(size=size)
        finally:
            self._serial.timeout = self._DEFAULT_TIMEOUT

    def __ping(self, timeout: float = 10) -> bool:
        """
        Send a break followed by the ping command and check for an ACK.

        :param timeout: seconds to wait for the ACK
        :type timeout: float
        :return: True if the device answered with an ACK
        :rtype: bool
        """
        self._serial.send_break()

        time.sleep(0.2)
        self._serial.write(CMD_PKG.PING_COMMAND)

        return is_ack_res(self._read_response(size=5, timeout=timeout))

    def erase(self, timeout: float = 300) -> Optional[bool]:
        """
        Erase the flash.

        :param timeout: seconds to wait for the ACK, 300 by default
        :type timeout: float
        :return: True if the device acknowledged, False if not, None if the
            device is not attached
        :rtype: Optional[bool]
        """
        if self.attached is False:
            logger.error('{}未连接！'.format(self.device_type.name))
            return None

        self._serial.reset_input_buffer()
        self._serial.write(CMD_PKG.ERASE_COMMAND)

        return is_ack_res(self._read_response(size=5, timeout=timeout))

    def get_status(self, timeout: float = 10) -> Optional[STATUS_RESPONSE]:
        """
        Query the current status of the device.

        Status values listed by the original author:
            INITIAL_STATUS = 0x00
            SUCCESS = 0x40
            ACCESS_IN_PROGRESS = 0x4B

        :param timeout: seconds to wait for the response
        :type timeout: float
        :return: the decoded status, or None if the device is not attached,
            the read timed out, the packet failed validation or the status
            value is unknown
        :rtype: Optional[STATUS_RESPONSE]
        """
        if self.attached is False:
            logger.error('{}未连接！'.format(self.device_type.name))
            return None

        self._serial.reset_input_buffer()
        self._serial.write(CMD_PKG.GET_STATUS_COMMAND)

        data_recv = self._read_response(size=4, timeout=timeout)

        if len(data_recv) != 4:
            logger.error('读取STATUS_RESPONSE超时，已收到的字节(hex):{}'.format(data_recv.hex()))
            return None

        # Response layout: 2-byte big-endian length, checksum byte, status byte.
        length, ck_sum, status = struct.unpack('>HBB', data_recv)
        # With a one-byte payload the checksum equals the status byte itself.
        if length != 3 or ck_sum != status:
            logger.error('数据包校验错误！数据包内容(hex):{}'.format(data_recv.hex()))
            return None
        try:
            return STATUS_RESPONSE(status)
        except ValueError:
            logger.error('未知的STATUS_RESPONSE：{}'.format(status))
            return None

    def get_version(self, timeout: float = 10) -> Optional[str]:
        """
        Read the version information reported by the device.

        The device first answers with an ACK, then with a 15-byte version
        packet. The first four payload bytes of that packet are returned as
        a hex string.

        :param timeout: seconds to wait for each of the two responses
        :type timeout: float
        :return: hex string of the first four payload bytes, or None on any
            failure (not attached, no ACK, timeout, validation error)
        :rtype: Optional[str]
        """
        if self.attached is False:
            logger.error('{}未连接！'.format(self.device_type.name))
            return None

        self._serial.reset_input_buffer()
        self._serial.write(CMD_PKG.GET_VERSION_COMMAND)

        # The ACK arrives first.
        data_recv = self._read_response(size=5, timeout=timeout)
        if not is_ack_res(data_recv):
            logger.error('读取VERSION失败，未收到ack消息:{}'.format(data_recv.hex()))
            return None

        # Then the version packet itself.
        data_recv = self._read_response(size=15, timeout=timeout)
        if len(data_recv) != 15:
            logger.error('读取VERSION超时，已收到字节(hex):{}'.format(data_recv.hex()))
            return None

        # Layout: 2-byte big-endian length, checksum byte, 12 payload bytes.
        length = struct.unpack('>H', data_recv[0:2])[0]
        expected_ck_sum = checksum(data_recv[3:])
        if length == 14 and data_recv[2:3] == expected_ck_sum:
            return data_recv[3:7].hex()

        logger.debug('data_recv[2:3]={}, checksum(data_recv[3:])={}'.format(data_recv[2:3].hex(),
                                                                            expected_ck_sum.hex()))
        logger.error('数据包校验错误，原数据包信息(hex):{}, length:{}'.format(data_recv.hex(), length))
        return None

    def download(self,
                 file_path: Optional[str] = None,
                 timeout: float = 300,
                 storage_type: STORAGE_TYPE = STORAGE_TYPE.SFLASH,
                 file_type: FILE_TYPE = FILE_TYPE.META_IMG_1,
                 show_progress_bars: bool = True,
                 erase: bool = False) -> Optional[bool]:
        """
        Download a firmware image to flash or RAM.

        Only writing META_IMG_1 to SFLASH has been tested in this version;
        other combinations are not guaranteed, use with care.

        :param file_path: path of the bin file; the demo image for the
            device is used when omitted
        :type file_path: str
        :param timeout: timeout in seconds for the optional erase step;
            values below 300 are raised to 300
        :type timeout: float
        :param storage_type: write to flash or to RAM, flash by default
        :type storage_type: STORAGE_TYPE
        :param file_type: which image partition to write, META_IMG_1 by
            default
        :type file_type: FILE_TYPE
        :param show_progress_bars: print progress messages and show a tqdm
            progress bar, enabled by default
        :type show_progress_bars: bool
        :param erase: erase the flash before downloading, False by default
        :type erase: bool
        :return: True on success, False on failure, None if the device is
            not attached
        :rtype: Optional[bool]
        """

        def _pr(info: str):
            if show_progress_bars:
                print(info)

        if self.attached is False:
            logger.error('{}未连接！'.format(self.device_type.name))
            return None
        if file_path is None:
            file_path = self.DEMO_IMG[self.device_type]
        timeout = max(timeout, 300)
        # TODO: the first version only tested writing META_IMG_1 to SFLASH;
        # other combinations are not guaranteed.
        if storage_type != STORAGE_TYPE.SFLASH or file_type != FILE_TYPE.META_IMG_1:
            logger.warning('当前版本仅测试了写入SFLASH的META_IMG_1，对其他情况不保证，慎用！')

        logger.debug('加载bin文件"{}"到内存'.format(file_path))
        try:
            with open(file_path, 'rb') as f:
                bin_file = f.read()
            logger.debug('"{}"加载完成'.format(file_path))
        except FileNotFoundError as e:
            logger.error('文件打开出错，请检查:{}'.format(e))
            return False
        if erase:
            _pr('开始擦除flash...')
            if self.erase(timeout=timeout):
                _pr('擦除成功')
            else:
                _pr('擦除失败，程序退出')
                return False

        # Derived parameters.
        file_size = len(bin_file)
        chunk_size = self._CHUNK_SIZE
        total_packets = (file_size + chunk_size - 1) // chunk_size  # ceiling division
        if show_progress_bars:
            from tqdm import tqdm
        # 14400 bytes/s is the author's throughput estimate (presumably
        # 115200 baud / 8 -- TODO confirm against BAUD_RATE).
        _pr('开始下载"{}"到{}, 固件大小:{}KB, 预计下载耗时{}秒'.
            format(file_path, self.device_type.name, file_size // 1024, file_size // 14400))

        # Pick the write-command builder once instead of once per packet.
        if storage_type == STORAGE_TYPE.SFLASH:
            make_write_cmd = CMD_PKG.creat_write_to_flash_cmd
        else:
            make_write_cmd = CMD_PKG.creat_write_to_ram_cmd

        # 1. Send the OPEN command.
        open_cmd = CMD_PKG.creat_open_cmd(file_size=file_size, storage_type=storage_type, file_type=file_type)
        time.sleep(0.25)
        start_time = get_timestamp_ms()
        logger.debug('发送OPEN Command at: {}'.format(start_time))
        self._serial.reset_input_buffer()
        num = self._serial.write(open_cmd.get_uart_pkg())
        logger.debug('open cmd写入{}字节'.format(num))

        # The ACK timeout is set once for the whole transfer rather than per
        # packet, and is always restored (and the bar closed) in ``finally``.
        pbar = None
        self._serial.timeout = self._ACK_TIMEOUT
        try:
            # 2. Wait for the ACK of the OPEN command.
            data_recv = self._serial.read(size=5)
            if is_ack_res(data_recv):
                logger.debug('接收到OPEN Command的ack, 耗时:{}ms'.format(get_timestamp_ms() - start_time))
            else:
                logger.error('接收OPEN Command的ack失败:{}'.format(data_recv.hex()))
                return False

            # 3. Send the bin file chunk by chunk, one ACK per chunk.
            logger.debug('开始发送bin文件...')
            if show_progress_bars:
                pbar = tqdm(total=total_packets)
            for i in range(total_packets):
                begin_ptr = i * chunk_size
                # Slicing clamps at the end of the file, so the last chunk
                # may be shorter than chunk_size.
                chunk = bin_file[begin_ptr:begin_ptr + chunk_size]
                self._serial.write(make_write_cmd(data=chunk).get_uart_pkg())
                if not is_ack_res(self._serial.read(size=5)):
                    message = '发生错误! 总包数:{}，正在发送第{}包，发生错误的包:{}'.format(
                        total_packets, i + 1, chunk)
                    # Log as well, so the failure is visible when progress
                    # output is disabled.
                    logger.error(message)
                    _pr(message)
                    return False
                if pbar is not None:
                    pbar.update(1)

            # Close the bar before printing any further messages.
            if pbar is not None:
                pbar.close()
                pbar = None

            # 4. Send the CLOSE command.
            close_pkg = CMD_PKG.creat_close_cmd(storage_type=storage_type).get_uart_pkg()
            logger.debug('发送CLOSE Command: {}'.format(close_pkg.hex()))
            self._serial.write(close_pkg)

            # 5. Wait for the ACK of the CLOSE command.
            data_recv = self._serial.read(size=5)
            if is_ack_res(data_recv):
                logger.debug('接收到CLOSE Command的ack')
            else:
                logger.error('接收CLOSE Command的ack失败:{}'.format(data_recv.hex()))
                _pr('CLOSE Command 执行失败, 请重试, ack:{}'.
                    format(data_recv.hex() if (len(data_recv) != 0) else "未收到"))
                return False
        finally:
            if pbar is not None:
                pbar.close()
            self._serial.timeout = self._DEFAULT_TIMEOUT

        elapsed_seconds = (get_timestamp_ms() - start_time) / 1000
        _pr('固件"{}"下载成功, 耗时:{}秒, 请将SOP2拉低，复位芯片以启动新的程序！'.
            format(file_path, elapsed_seconds))
        logger.info('为芯片: {}更新固件: "{}"成功, 耗时:{}秒, MD5:{}'.
                    format(self.device_type.name, file_path, elapsed_seconds, md5(bin_file)))
        # Report success explicitly; falling off the end would return None,
        # which is the value used for "not attached".
        return True
